mod aggregation;
mod alias;
mod apply;
mod binary;
mod cast;
mod column;
mod count;
mod element;
mod eval;
mod filter;
mod gather;
mod group_iter;
mod literal;
#[cfg(feature = "dynamic_group_by")]
mod rolling;
mod slice;
mod sort;
mod sortby;
mod ternary;
mod window;

use std::borrow::Cow;
use std::fmt::{Display, Formatter};

pub(crate) use aggregation::*;
pub(crate) use alias::*;
pub(crate) use apply::*;
use arrow::array::ArrayRef;
use arrow::bitmap::MutableBitmap;
use arrow::legacy::utils::CustomIterTools;
pub(crate) use binary::*;
pub(crate) use cast::*;
pub(crate) use column::*;
pub(crate) use count::*;
pub(crate) use element::*;
pub(crate) use eval::*;
pub(crate) use filter::*;
pub(crate) use gather::*;
pub(crate) use literal::*;
use polars_core::prelude::*;
use polars_io::predicates::PhysicalIoExpr;
use polars_plan::prelude::*;
#[cfg(feature = "dynamic_group_by")]
pub(crate) use rolling::RollingExpr;
pub(crate) use slice::*;
pub(crate) use sort::*;
pub(crate) use sortby::*;
pub(crate) use ternary::*;
pub use window::window_function_format_order_by;
pub(crate) use window::*;

use crate::state::ExecutionState;

/// Aggregation state of the values held by an [`AggregationContext`].
///
/// Every variant wraps the `Column` that holds the values in that state.
#[derive(Clone, Debug)]
pub enum AggState {
    /// Already aggregated: `.agg_list(group_tuples)` was called
    /// and produced a `Series` of dtype `List`.
    AggregatedList(Column),
    /// Already aggregated: `.agg` was called on an aggregation
    /// that produces a scalar.
    /// Think of aggregations like `sum`, `mean` and `variance`.
    AggregatedScalar(Column),
    /// Not yet aggregated: `agg_list` still has to be called.
    NotAggregated(Column),
    /// A literal scalar value.
    LiteralScalar(Column),
}

impl AggState {
    /// Apply the fallible `func` to the wrapped column and re-wrap the result in the
    /// same variant. An error returned by `func` is propagated unchanged.
    fn try_map<F>(&self, func: F) -> PolarsResult<Self>
    where
        F: FnOnce(&Column) -> PolarsResult<Column>,
    {
        // Every variant carries exactly one column, so this binding is irrefutable.
        let (Self::AggregatedList(inner)
        | Self::AggregatedScalar(inner)
        | Self::LiteralScalar(inner)
        | Self::NotAggregated(inner)) = self;
        let mapped = func(inner)?;

        Ok(match self {
            Self::AggregatedList(_) => Self::AggregatedList(mapped),
            Self::AggregatedScalar(_) => Self::AggregatedScalar(mapped),
            Self::LiteralScalar(_) => Self::LiteralScalar(mapped),
            Self::NotAggregated(_) => Self::NotAggregated(mapped),
        })
    }

    /// `true` only for the `AggregatedScalar` variant.
    fn is_scalar(&self) -> bool {
        matches!(*self, AggState::AggregatedScalar(..))
    }

    /// Name of the wrapped column.
    pub fn name(&self) -> &PlSmallStr {
        let (Self::AggregatedList(inner)
        | Self::NotAggregated(inner)
        | Self::LiteralScalar(inner)
        | Self::AggregatedScalar(inner)) = self;
        inner.name()
    }

    /// Data type of the flat values: the inner dtype for `AggregatedList`, and the
    /// dtype of the wrapped column for every other variant.
    ///
    /// Panics if an `AggregatedList` column has no inner dtype.
    pub fn flat_dtype(&self) -> &DataType {
        match self {
            Self::NotAggregated(col) | Self::LiteralScalar(col) | Self::AggregatedScalar(col) => {
                col.dtype()
            },
            Self::AggregatedList(list_col) => list_col.dtype().inner_dtype().unwrap(),
        }
    }
}

/// Lazy update strategy for the groups of an `AggregationContext`.
#[derive(Debug, PartialEq, Clone, Copy)]
pub(crate) enum UpdateGroups {
    /// Don't update the groups.
    No,
    /// Use the length of the current groups to determine new sorted indexes; preferred
    /// for performance.
    WithGroupsLen,
    /// Use the series list offsets to determine the new group lengths.
    /// This one should be used when the length has changed. Note that
    /// the series should be in an aggregated state or else it will panic.
    WithSeriesLen,
}

/// Values being evaluated in a group context, together with the groups they belong to.
#[cfg_attr(debug_assertions, derive(Debug))]
pub struct AggregationContext<'a> {
    /// The values and their aggregation state (see `AggState`), e.g.
    /// 1. already aggregated as list
    /// 2. flat (still needs the group tuples to aggregate)
    ///
    /// When aggregation state is LiteralScalar or AggregatedScalar, the group values are not
    /// related to the state data anymore. The number of groups is still accurate.
    pub(crate) state: AggState,
    /// Group tuples for the `AggState`.
    pub(crate) groups: Cow<'a, GroupPositions>,
    /// This is used to determine whether we need to update the groups
    /// into sorted groups. We do this lazily, so that this work is only
    /// done when the groups are needed.
    pub(crate) update_groups: UpdateGroups,
    /// This is true when the Series and Groups still have all
    /// their original values. Not the case when filtered.
    pub(crate) original_len: bool,
}

impl<'a> AggregationContext<'a> {
    /// Return the groups, first applying the pending lazy update (if any) selected by
    /// `update_groups`.
    pub(crate) fn groups(&mut self) -> &Cow<'a, GroupPositions> {
        match self.update_groups {
            UpdateGroups::No => {},
            UpdateGroups::WithGroupsLen => {
                // The series is aggregated with the current groups, so the new group
                // tuples must match the exploded series: back-to-back slices that keep
                // the length of every current group.
                let mut offset: IdxSize = 0;
                let mut next_slice = |len: IdxSize| {
                    let slice = [offset, len];
                    offset += len;
                    slice
                };

                let contiguous = match self.groups.as_ref().as_ref() {
                    // The index groups are unordered; only their lengths are reused.
                    GroupsType::Idx(idx) => idx
                        .iter()
                        .map(|g| next_slice(g.1.len() as IdxSize))
                        .collect(),
                    // Sliced groups are already in the correct order; the offsets are
                    // unrolled for the case of overlapping groups,
                    // e.g. [0,2], [1,3], [2,4] becomes [0,2], [2,3], [5,4]
                    GroupsType::Slice { groups, .. } => {
                        groups.iter().map(|g| next_slice(g[1])).collect()
                    },
                };
                self.groups =
                    Cow::Owned(GroupsType::new_slice(contiguous, false, true).into_sliceable());
                self.update_groups = UpdateGroups::No;
            },
            UpdateGroups::WithSeriesLen => {
                // Clone the values so that `self` can be mutated while they are read.
                let values = self.get_values().clone();
                self.det_groups_from_list(values.as_materialized_series());
            },
        }
        &self.groups
    }

    /// Borrow the column held by the current state, whatever that state is.
    pub(crate) fn get_values(&self) -> &Column {
        // Every variant wraps exactly one column, so this binding is irrefutable.
        let (AggState::NotAggregated(values)
        | AggState::AggregatedScalar(values)
        | AggState::AggregatedList(values)
        | AggState::LiteralScalar(values)) = &self.state;
        values
    }

    /// Borrow the current aggregation state.
    pub fn agg_state(&self) -> &AggState {
        &self.state
    }

    /// `true` when the state is `NotAggregated` or `LiteralScalar`.
    pub(crate) fn is_not_aggregated(&self) -> bool {
        match self.state {
            AggState::NotAggregated(_) | AggState::LiteralScalar(_) => true,
            AggState::AggregatedList(_) | AggState::AggregatedScalar(_) => false,
        }
    }

    pub(crate) fn is_aggregated(&self) -> bool {
        !self.is_not_aggregated()
    }

    /// `true` when the state is `LiteralScalar`.
    pub(crate) fn is_literal(&self) -> bool {
        matches!(self.state, AggState::LiteralScalar(_))
    }

    /// # Arguments
    /// - `aggregated` sets if the Series is a list due to aggregation (could also be a list because its
    ///   the columns dtype)
    fn new(
        column: Column,
        groups: Cow<'a, GroupPositions>,
        aggregated: bool,
    ) -> AggregationContext<'a> {
        let series = if aggregated {
            assert_eq!(column.len(), groups.len());
            AggState::AggregatedScalar(column)
        } else {
            AggState::NotAggregated(column)
        };

        Self {
            state: series,
            groups,
            update_groups: UpdateGroups::No,
            original_len: true,
        }
    }

    /// Overwrite the aggregation state; groups and flags are left untouched.
    fn with_agg_state(&mut self, agg_state: AggState) {
        self.state = agg_state;
    }

    /// Build a context from an existing state and groups, with no pending group update
    /// and `original_len` set to `true`.
    fn from_agg_state(
        agg_state: AggState,
        groups: Cow<'a, GroupPositions>,
    ) -> AggregationContext<'a> {
        Self {
            state: agg_state,
            groups,
            update_groups: UpdateGroups::No,
            original_len: true,
        }
    }

    /// Set the `original_len` flag and return `self` for chaining.
    pub(crate) fn set_original_len(&mut self, original_len: bool) -> &mut Self {
        self.original_len = original_len;
        self
    }

    /// Set the pending (lazy) group update strategy and return `self` for chaining.
    pub(crate) fn with_update_groups(&mut self, update: UpdateGroups) -> &mut Self {
        self.update_groups = update;
        self
    }

    /// Determine the groups from the lengths of the lists in `s`: the groups become
    /// back-to-back slices `[offset, len]`, one per list, and the pending group update
    /// is cleared.
    ///
    /// # Panics
    /// Panics if the series is not of dtype `List`.
    fn det_groups_from_list(&mut self, s: &Series) {
        let mut offset = 0 as IdxSize;
        let list = s
            .list()
            .expect("impl error, should be a list at this point");

        match list.chunks().len() {
            1 => {
                let arr = list.downcast_iter().next().unwrap();
                let offsets = arr.offsets().as_slice();

                // Seed with the first offset instead of assuming it is 0: the offsets of
                // an array do not have to start at 0 (e.g. after slicing), and assuming
                // so would inflate the length of the first group.
                let mut previous = offsets[0];
                let groups = offsets[1..]
                    .iter()
                    .map(|&o| {
                        let len = (o - previous) as IdxSize;
                        let new_offset = offset + len;

                        previous = o;
                        let out = [offset, len];
                        offset = new_offset;
                        out
                    })
                    .collect_trusted();
                self.groups =
                    Cow::Owned(GroupsType::new_slice(groups, false, true).into_sliceable());
            },
            _ => {
                // NOTE(review): this branch reads the lists from `self.get_values()`
                // rather than from `s`; for the call in `groups()` both are the same
                // data, since `s` is a clone of the values.
                let groups = {
                    self.get_values()
                        .list()
                        .expect("impl error, should be a list at this point")
                        .amortized_iter()
                        .map(|s| {
                            if let Some(s) = s {
                                let len = s.as_ref().len() as IdxSize;
                                let new_offset = offset + len;
                                let out = [offset, len];
                                offset = new_offset;
                                out
                            } else {
                                // A null list contributes an empty group.
                                [offset, 0]
                            }
                        })
                        .collect_trusted()
                };
                self.groups =
                    Cow::Owned(GroupsType::new_slice(groups, false, true).into_sliceable());
            },
        }
        self.update_groups = UpdateGroups::No;
    }

    /// Replace the values, delegating to `with_values_and_args` with
    /// `preserve_literal = false` and `returns_scalar` taken from whether the current
    /// state is a scalar aggregation.
    ///
    /// # Arguments
    /// - `aggregated` sets if the Series is a list due to aggregation (could also be a list
    ///   because it is the column's dtype)
    /// - `expr` is only used to enrich the error message.
    pub(crate) fn with_values(
        &mut self,
        column: Column,
        aggregated: bool,
        expr: Option<&Expr>,
    ) -> PolarsResult<&mut Self> {
        let preserve_literal = false;
        let returns_scalar = self.state.is_scalar();
        self.with_values_and_args(column, aggregated, expr, preserve_literal, returns_scalar)
    }

    /// Replace the values and derive the new aggregation state from the arguments and
    /// the current state.
    ///
    /// # Arguments
    /// - `column`: the new values.
    /// - `aggregated`: whether `column` is the result of an aggregation.
    /// - `expr`: only used to enrich the error message.
    /// - `preserve_literal`: keep a length-1 result of a function applied on a literal
    ///   in the literal state.
    /// - `returns_scalar`: when `true`, an aggregated `List` column is stored as
    ///   `AggState::AggregatedScalar` instead of `AggState::AggregatedList`.
    ///
    /// # Errors
    /// Returns a `ComputeError` when an aggregated list column does not have exactly one
    /// element per group.
    pub(crate) fn with_values_and_args(
        &mut self,
        column: Column,
        aggregated: bool,
        expr: Option<&Expr>,
        // if the applied function was a `map` instead of an `apply`
        // this will keep functions applied over literals as literals: F(lit) = lit
        preserve_literal: bool,
        returns_scalar: bool,
    ) -> PolarsResult<&mut Self> {
        self.state = match (aggregated, column.dtype()) {
            (true, &DataType::List(_)) if !returns_scalar => {
                if column.len() != self.groups.len() {
                    // `fmt_expr` carries its own quotes and trailing space, so that the
                    // message reads naturally with and without an expression.
                    let fmt_expr = if let Some(e) = expr {
                        format!("'{e:?}' ")
                    } else {
                        String::new()
                    };
                    polars_bail!(
                        ComputeError:
                        "aggregation expression {}produced a different number of elements: {} \
                        than the number of groups: {} (this is likely invalid)",
                        fmt_expr, column.len(), self.groups.len(),
                    );
                }
                AggState::AggregatedList(column)
            },
            (true, _) => AggState::AggregatedScalar(column),
            _ => {
                match self.state {
                    // Already aggregated to a scalar (sum, min, ...): even if this series
                    // was flattened it could never retrieve the length it had before
                    // grouping, so it stays in this state.
                    AggState::AggregatedScalar(_) => AggState::AggregatedScalar(column),
                    // applying a function on a literal, keeps the literal state
                    AggState::LiteralScalar(_) if column.len() == 1 && preserve_literal => {
                        AggState::LiteralScalar(column)
                    },
                    _ => AggState::NotAggregated(column),
                }
            },
        };
        Ok(self)
    }

    /// Replace the state with `AggState::LiteralScalar(column)` and return `self` for
    /// chaining. The groups and flags are left untouched.
    pub(crate) fn with_literal(&mut self, column: Column) -> &mut Self {
        self.state = AggState::LiteralScalar(column);
        self
    }

    /// Replace the group tuples and discard any pending lazy group update.
    ///
    /// Panics if re-setting the flattened values fails.
    pub(crate) fn with_groups(&mut self, groups: GroupPositions) -> &mut Self {
        if matches!(self.state, AggState::AggregatedList(_)) {
            // With new groups, an aggregated list always has to be flattened first.
            let flat = self.flat_naive().into_owned();
            self.with_values(flat, false, None).unwrap();
        }

        self.groups = Cow::Owned(groups);
        // Make sure that a previous setting is not used.
        self.update_groups = UpdateGroups::No;
        self
    }

    /// Ensure that each group is represented by contiguous values in memory.
    ///
    /// The flat values are re-aggregated into a list with the (updated) groups; the
    /// state becomes `AggregatedList`, `original_len` is cleared and a `WithGroupsLen`
    /// group update is scheduled.
    pub fn normalize_values(&mut self) {
        self.original_len = false;
        // Apply any pending group update before the groups are used.
        self.groups();

        let listed = {
            let flat = self.flat_naive();
            // SAFETY: NOTE(review) relies on the groups being in bounds of the flat
            // values; this is not checked here.
            unsafe { flat.agg_list(&self.groups) }
        };

        self.state = AggState::AggregatedList(listed);
        self.update_groups = UpdateGroups::WithGroupsLen;
    }

    /// Aggregate (see `aggregated`) and return the values as a `ListChunked`.
    ///
    /// A scalar aggregation is converted with `as_list` and returned owned; any other
    /// state is expected to hold a list column already and is borrowed.
    pub fn aggregated_as_list<'b>(&'b mut self) -> Cow<'b, ListChunked> {
        // Only the state change is needed; the returned column is discarded.
        self.aggregated();

        if let AggState::AggregatedScalar(scalars) = &self.state {
            return Cow::Owned(scalars.as_list());
        }
        Cow::Borrowed(self.get_values().list().unwrap())
    }

    /// Get the aggregated version of the series.
    ///
    /// - `AggregatedList` / `AggregatedScalar`: the column is returned as is.
    /// - `NotAggregated`: the values are aggregated into a list with the groups; the
    ///   state becomes `AggregatedList` and a `WithGroupsLen` update is scheduled.
    /// - `LiteralScalar`: the literal is imploded and repeated once per group; the state
    ///   becomes `AggregatedList` and a `WithSeriesLen` update is scheduled.
    pub fn aggregated(&mut self) -> Column {
        // The state is cloned so that `self.groups()` is only called when needed:
        // it may instantiate new groups and can thus be expensive.
        match self.state.clone() {
            AggState::AggregatedList(done) | AggState::AggregatedScalar(done) => {
                done.into_column()
            },
            AggState::NotAggregated(flat) => {
                // The groups are updated lazily, and they are needed to aggregate the
                // flat values into a list, so they have to be brought up to date first.
                self.groups();
                #[cfg(debug_assertions)]
                {
                    if self.groups.len() > flat.len() {
                        polars_warn!(
                            "groups may be out of bounds; more groups than elements in a series is only possible in dynamic group_by"
                        )
                    }
                }

                // SAFETY:
                // groups are in bounds
                let listed = unsafe { flat.agg_list(&self.groups) };
                self.state = AggState::AggregatedList(listed.clone());
                self.update_groups = UpdateGroups::WithGroupsLen;
                listed
            },
            AggState::LiteralScalar(literal) => {
                // NOTE(review): unlike `finalize`, no pending group update is applied
                // before the number of groups is read here.
                let n_groups = self.groups.len();
                let repeated = literal
                    .implode()
                    .unwrap()
                    .new_from_index(0, n_groups)
                    .into_column();
                self.state = AggState::AggregatedList(repeated.clone());
                self.update_groups = UpdateGroups::WithSeriesLen;
                repeated
            },
        }
    }

    /// Get the final aggregated version of the series.
    ///
    /// A literal is repeated once per group (without imploding it into a list); every
    /// other state is delegated to `aggregated`.
    pub fn finalize(&mut self) -> Column {
        if let AggState::LiteralScalar(literal) = &self.state {
            // Clone first, so that `self` can be borrowed mutably to update the groups.
            let literal = literal.clone();
            self.groups();
            let n_groups = self.groups.len();
            return literal.new_from_index(0, n_groups);
        }

        self.aggregated()
    }

    // If a binary or ternary function has both of these branches true, it should
    // flatten the list
    fn arity_should_explode(&self) -> bool {
        use AggState::*;
        match self.agg_state() {
            LiteralScalar(s) => s.len() == 1,
            AggregatedScalar(_) => true,
            _ => false,
        }
    }

    pub fn get_final_aggregation(mut self) -> (Column, Cow<'a, GroupPositions>) {
        let _ = self.groups();
        let groups = self.groups;
        match self.state {
            AggState::NotAggregated(c) => (c, groups),
            AggState::AggregatedScalar(c) => (c, groups),
            AggState::LiteralScalar(c) => (c, groups),
            AggState::AggregatedList(c) => {
                let flattened = c
                    .explode(ExplodeOptions {
                        empty_as_null: false,
                        keep_nulls: true,
                    })
                    .unwrap();
                let groups = groups.into_owned();
                // unroll the possible flattened state
                // say we have groups with overlapping windows:
                //
                // offset, len
                // 0, 1
                // 0, 2
                // 0, 4
                //
                // gets aggregation
                //
                // [0]
                // [0, 1],
                // [0, 1, 2, 3]
                //
                // before aggregation the column was
                // [0, 1, 2, 3]
                // but explode on this list yields
                // [0, 0, 1, 0, 1, 2, 3]
                //
                // so we unroll the groups as
                //
                // [0, 1]
                // [1, 2]
                // [3, 4]
                let groups = groups.unroll();
                (flattened, Cow::Owned(groups))
            },
        }
    }

    /// Get the not-aggregated version of the series.
    ///
    /// It is called naive because, if a previous expression has filtered or sorted the
    /// values, that information lives in the group tuples and not in the flattened
    /// series. Only an aggregated list has to be exploded (and is returned owned); all
    /// other states are borrowed as they are.
    pub(crate) fn flat_naive(&self) -> Cow<'_, Column> {
        let list = match &self.state {
            AggState::AggregatedList(list) => list,
            AggState::NotAggregated(flat)
            | AggState::AggregatedScalar(flat)
            | AggState::LiteralScalar(flat) => return Cow::Borrowed(flat),
        };

        // Warn in debug builds, so that cases are found where overlapping groups are
        // exploded by accident: this can create a lot of data.
        if cfg!(debug_assertions) && self.groups.is_overlapping() {
            polars_warn!(
                "performance - an aggregated list with overlapping groups may consume excessive memory"
            )
        }

        // Nulls must not be inserted for empty lists, otherwise the offsets in the
        // groups would not be correct.
        let explode_options = ExplodeOptions {
            empty_as_null: false,
            keep_nulls: true,
        };
        Cow::Owned(list.explode(explode_options).unwrap())
    }

    /// Length used for the flat values of the current state: the inner length for an
    /// aggregated list, the column length for flat and scalar-aggregated values, and
    /// always 1 for a literal.
    fn flat_naive_length(&self) -> usize {
        match &self.state {
            AggState::LiteralScalar(_) => 1,
            AggState::AggregatedList(list) => list.list().unwrap().inner_length(),
            AggState::NotAggregated(col) | AggState::AggregatedScalar(col) => col.len(),
        }
    }

    /// Take the series out of the state, leaving a default column in its place. The
    /// state variant itself is not changed.
    pub(crate) fn take(&mut self) -> Column {
        // Every variant wraps exactly one column, so this binding is irrefutable.
        let (AggState::NotAggregated(values)
        | AggState::AggregatedScalar(values)
        | AggState::AggregatedList(values)
        | AggState::LiteralScalar(values)) = &mut self.state;
        std::mem::take(values)
    }

    /// Do the group indices reference all values in the aggregation state.
    ///
    /// Always `true` for `LiteralScalar` and `AggregatedScalar`. For the other states
    /// any pending group update is applied first (hence `&mut self`).
    fn groups_cover_all_values(&mut self) -> bool {
        if matches!(
            self.state,
            AggState::LiteralScalar(_) | AggState::AggregatedScalar(_)
        ) {
            return true;
        }

        let num_values = self.flat_naive_length();
        match self.groups().as_ref().as_ref() {
            GroupsType::Idx(groups) => {
                // Mark every referenced index, then check that none is left unmarked.
                let mut seen = MutableBitmap::from_len_zeroed(num_values);
                for (_, g) in groups {
                    for i in g.iter() {
                        // SAFETY: NOTE(review) relies on every group index being smaller
                        // than `num_values`; this is not checked here.
                        unsafe { seen.set_unchecked(*i as usize, true) };
                    }
                }
                seen.unset_bits() == 0
            },
            GroupsType::Slice {
                groups,
                overlapping: true,
                monotonic: _,
            } => {
                // @NOTE: Slice groups are sorted by their `start` value.
                //
                // `offset` is the end of the prefix covered so far. It has to be a running
                // maximum: a group nested inside an earlier one (e.g. [0,4] then [1,1])
                // must not move the covered end backwards.
                let mut offset: IdxSize = 0;
                let mut covers_all = true;
                for [start, length] in groups {
                    // A group starting past the covered prefix leaves a gap.
                    covers_all &= *start <= offset;
                    offset = offset.max(start + length);
                }
                covers_all && offset == num_values as IdxSize
            },

            // If we don't have overlapping data, we can just do a count.
            GroupsType::Slice {
                groups,
                overlapping: false,
                monotonic: _,
            } => groups.iter().map(|[_, l]| *l as usize).sum::<usize>() == num_values,
        }
    }

    /// Fixes groups for `AggregatedScalar` and `LiteralScalar` so that they point to valid
    /// data elements in the `AggState` values.
    ///
    /// A scalar aggregation gets one length-1 slice per value; a literal gets one
    /// `[0, 1]` slice per existing group. Other states are left untouched.
    ///
    /// Panics if a group update is still pending, or if the literal is not of length 1.
    fn set_groups_for_undefined_agg_states(&mut self) {
        let fixed = match &self.state {
            AggState::AggregatedList(_) | AggState::NotAggregated(_) => return,
            AggState::AggregatedScalar(scalars) => {
                assert_eq!(self.update_groups, UpdateGroups::No);
                let slices = (0..scalars.len() as IdxSize).map(|i| [i, 1]).collect();
                GroupsType::new_slice(slices, false, true)
            },
            AggState::LiteralScalar(literal) => {
                assert_eq!(literal.len(), 1);
                assert_eq!(self.update_groups, UpdateGroups::No);
                // Every group points at the single literal value.
                let slices = vec![[0, 1]; self.groups.len()];
                GroupsType::new_slice(slices, true, true)
            },
        };
        self.groups = Cow::Owned(fixed.into_sliceable());
    }

    pub fn into_static(&self) -> AggregationContext<'static> {
        let groups: GroupPositions = GroupPositions::to_owned(&self.groups);
        let groups: Cow<'static, GroupPositions> = Cow::Owned(groups);
        AggregationContext {
            state: self.state.clone(),
            groups,
            update_groups: self.update_groups,
            original_len: self.original_len,
        }
    }
}

/// Take a DataFrame and evaluate the expressions.
/// Implement this for Column, lt, eq, etc
pub trait PhysicalExpr: Send + Sync {
    /// The logical expression behind this physical expression, if any.
    /// Defaults to `None`.
    fn as_expression(&self) -> Option<&Expr> {
        None
    }

    /// Defaults to `None`.
    /// NOTE(review): presumably the column name when this expression is a plain column
    /// selection; confirm against the implementors.
    fn as_column(&self) -> Option<PlSmallStr> {
        None
    }

    /// Take a DataFrame and evaluate the expression.
    fn evaluate(&self, df: &DataFrame, _state: &ExecutionState) -> PolarsResult<Column>;

    /// Some expressions that are not aggregations can be done per group.
    /// Think of sort, slice, filter, shift, etc.
    /// Defaults to ignoring the group.
    ///
    /// This method is called by an aggregation function.
    ///
    /// In case of a simple expr, like 'column', the groups are ignored and the column is returned.
    /// In case of an expr where group behavior makes sense, this method is called.
    /// For a filter operation for instance, a Series is created per group and filtered.
    ///
    /// An implementation of this method may apply an aggregation on the groups only. For instance
    /// on a shift, the groups are first aggregated to a `ListChunked` and the shift is applied per
    /// group. The implementation then has to return the `Series` exploded (because a later aggregation
    /// will use the group tuples to aggregate). The group tuples also have to be updated, because
    /// aggregation to a list sorts the exploded `Series` by group.
    ///
    /// This has some gotchas. An implementation may also change the group tuples instead of
    /// the `Series`.
    ///
    // we allow this because we pass the vec to the Cow
    // Note to self: Don't be smart and dispatch to evaluate as default implementation
    // this means filters will be incorrect and lead to invalid results down the line
    #[allow(clippy::ptr_arg)]
    fn evaluate_on_groups<'a>(
        &self,
        df: &DataFrame,
        groups: &'a GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<AggregationContext<'a>>;

    /// Get the output field of this expr
    fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field>;

    /// Defaults to `false`.
    /// NOTE(review): presumably overridden by literal expressions; confirm against the
    /// implementors.
    fn is_literal(&self) -> bool {
        false
    }
    /// Has no default; every implementor must provide it.
    /// NOTE(review): presumably reports whether the expression yields a single value;
    /// confirm against the implementors.
    fn is_scalar(&self) -> bool;
}

/// Formats the expression using the `Debug` representation of its logical expression;
/// nothing is written when there is none.
impl Display for &dyn PhysicalExpr {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        if let Some(expr) = self.as_expression() {
            write!(f, "{expr:?}")?;
        }
        Ok(())
    }
}

/// Wrapper struct that allows us to use a PhysicalExpr in polars-io.
///
/// This is used to filter rows during the scan of a file.
pub struct PhysicalIoHelper {
    /// The wrapped physical expression that is evaluated.
    pub expr: Arc<dyn PhysicalExpr>,
    /// When `true`, the window-function flag is set on the execution state before the
    /// expression is evaluated.
    pub has_window_function: bool,
}

impl PhysicalIoExpr for PhysicalIoHelper {
    /// Evaluate the wrapped expression on `df` with a fresh execution state and return
    /// the result as a materialized `Series`.
    ///
    /// A length-1 result is repeated to the height of `df` when that height is not 1.
    fn evaluate_io(&self, df: &DataFrame) -> PolarsResult<Series> {
        let mut state: ExecutionState = Default::default();
        if self.has_window_function {
            state.insert_has_window_function_flag();
        }

        let out = self.expr.evaluate(df, &state)?;
        // IO expression result should be boolean-typed.
        debug_assert_eq!(out.dtype(), &DataType::Boolean);

        let height = df.height();
        let out = if out.len() == 1 && height != 1 {
            // filter(lit(True)) will hit here.
            out.new_from_index(0, height)
        } else {
            out
        };
        Ok(out.take_materialized_series())
    }
}

/// Wrap a physical expression in a `PhysicalIoHelper` so it can be used as a
/// `PhysicalIoExpr`.
///
/// The helper records whether the logical expression (if there is one) contains an
/// `Expr::Over` node or, with the `dynamic_group_by` feature, an `Expr::Rolling` node.
pub fn phys_expr_to_io_expr(expr: Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalIoExpr> {
    let has_window_function = match expr.as_expression() {
        None => false,
        Some(root) => root.into_iter().any(|node| {
            #[cfg(feature = "dynamic_group_by")]
            if matches!(node, Expr::Rolling { .. }) {
                return true;
            }

            matches!(node, Expr::Over { .. })
        }),
    };

    let io_expr: Arc<dyn PhysicalIoExpr> = Arc::new(PhysicalIoHelper {
        expr,
        has_window_function,
    });
    io_expr
}
